package com.smt.config;

import com.smt.filter.PostGatewayFilterFactory;
import com.smt.filter.PreGatewayFilterFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.client.discovery.DiscoveryClient;
import org.springframework.cloud.gateway.discovery.DiscoveryClientRouteDefinitionLocator;
import org.springframework.cloud.gateway.discovery.DiscoveryLocatorProperties;
import org.springframework.cloud.gateway.filter.ratelimit.KeyResolver;
import org.springframework.cloud.gateway.filter.ratelimit.RedisRateLimiter;
import org.springframework.cloud.gateway.route.RouteDefinitionLocator;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.cloud.gateway.route.builder.RouteLocatorBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.web.cors.reactive.CorsUtils;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono;

/**
 * Gateway configuration: CORS handling, discovery-based route definitions, the Java-DSL routes,
 * the rate-limiting key resolver and the custom pre/post filter factories.
 *
 * @author xiaoyu.fang
 * @date 2019/5/24 11:08
 */
@Configuration
public class RouteConfiguration {

    /**
     * Key resolver used by the request rate limiter on the {@code /smt-service-sys/**} route.
     * Injected by type; the candidate declared in this class is {@link #remoteAddrKeyResolver()}.
     * NOTE(review): this injects a bean declared by this same configuration class (a circular
     * reference). Spring Boot versions that forbid circular references by default may reject it;
     * consider taking the KeyResolver as a parameter of {@link #smpRoutes} instead.
     */
    @Autowired
    private KeyResolver addressKeyResolver;

    // Request headers accepted in CORS requests; add any custom header names here.
    // '*' is likely not honoured because Access-Control-Allow-Credentials is "true": for
    // credentialed requests browsers treat '*' literally (not as a wildcard) in the
    // Allow-Headers / Allow-Methods / Expose-Headers / Allow-Origin headers (Fetch standard).
    private static final String ALLOWED_HEADERS = "x-requested-with, authorization, Content-Type, Authorization, credential, X-XSRF-TOKEN,token,username,client";
    // NOTE(review): per the note above, "*" here only works for non-credentialed requests.
    private static final String ALLOWED_METHODS = "*";
    // NOTE(review): browsers reject Allow-Origin "*" combined with Allow-Credentials "true" for
    // credentialed requests; supporting those requires echoing an allow-listed Origin instead.
    private static final String ALLOWED_ORIGIN = "*";
    private static final String ALLOWED_EXPOSE_HEADERS = "*";
    // Preflight cache lifetime in seconds; the header value must be digits only.
    private static final String MAX_AGE = "18000";

    /**
     * CORS handling: adds the CORS response headers to every CORS request and answers
     * {@code OPTIONS} requests directly with 200, without forwarding them down the chain.
     *
     * @return the CORS web filter
     */
    @Bean
    public WebFilter corsFilter() {
        return (ServerWebExchange ctx, WebFilterChain chain) -> {
            ServerHttpRequest request = ctx.getRequest();
            if (CorsUtils.isCorsRequest(request)) {
                ServerHttpResponse response = ctx.getResponse();
                HttpHeaders headers = response.getHeaders();
                // set (not add): each CORS header must carry a single value, even if another
                // filter has already populated it.
                headers.set(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN, ALLOWED_ORIGIN);
                headers.set(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS, ALLOWED_METHODS);
                headers.set(HttpHeaders.ACCESS_CONTROL_MAX_AGE, MAX_AGE);
                headers.set(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS, ALLOWED_HEADERS);
                headers.set(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS, ALLOWED_EXPOSE_HEADERS);
                headers.set(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
                if (request.getMethod() == HttpMethod.OPTIONS) {
                    // Answer OPTIONS here with the headers above; it is not routed further.
                    response.setStatusCode(HttpStatus.OK);
                    return Mono.empty();
                }
            }
            return chain.filter(ctx);
        };
    }

    /**
     * Route definitions built from the services known to the discovery client (e.g. Eureka),
     * driven by {@link DiscoveryLocatorProperties}. The related properties live in the yml.
     * NOTE(review): if the yml also enables spring.cloud.gateway.discovery.locator, the gateway
     * auto-configuration may register a bean with this same name; confirm the two do not clash.
     */
    @Bean
    public RouteDefinitionLocator discoveryClientRouteDefinitionLocator(DiscoveryClient discoveryClient, DiscoveryLocatorProperties properties) {
        return new DiscoveryClientRouteDefinitionLocator(discoveryClient, properties);
    }

    /**
     * Java-DSL routes.
     * <ul>
     *   <li>{@code /smt-service-sys/**} to {@code lb://smt-service-sys}: retry, custom pre/post
     *   filters, Redis rate limiting keyed by {@link #addressKeyResolver}, prefix stripping,
     *   Hystrix fallback and an extra request header.</li>
     *   <li>{@code /smt-service-bus/**} to {@code lb://smt-service-business}: retry, extra request
     *   header and prefix stripping.</li>
     * </ul>
     * Each route has its own weight group. A weight group splits traffic among the routes in it:
     * for every request one route of the group is drawn and only that route's weight predicate
     * passes. Routes with different paths therefore must not share a group, or each would reject
     * the share of its own requests that the draw assigned to the other.
     * NOTE(review): if other routes (e.g. in yml) are meant to be weighted against one of these,
     * give them that route's group name.
     */
    @Bean
    public RouteLocator smpRoutes(RouteLocatorBuilder builder) {
        return builder.routes()
                // Other predicates available on the builder, e.g.:
                // .and().method("POST"), .and().query("name", "xxx"),
                // .and().after/before/between(ZonedDateTime ...), .and().cookie("cookie", "xxx"),
                // .and().header("header", "xxxx"), .and().host("127.0.0.1", "192.168.0.1")
                .route(r -> r.weight("provide-sys", 90).and().path("/smt-service-sys/**")
                        // retry up to 3 times, then the pre- and post-filters from the beans below
                        .filters(f -> f.retry(3)
                                .filter(preGatewayFilterFactory().apply())
                                .filter(postGatewayFilterFactory().apply())
                                // token bucket: burstCapacity = bucket size (max requests allowed in
                                // one second); replenishRate = requests allowed per second per key
                                .requestRateLimiter()
                                .rateLimiter(RedisRateLimiter.class,
                                        config -> config.setBurstCapacity(20).setReplenishRate(10))
                                // rate-limit key comes from addressKeyResolver
                                .configure(config -> config.setKeyResolver(addressKeyResolver))
                                .stripPrefix(1)
                                // circuit breaker: on failure forward to "/fallback"
                                .hystrix(config -> config.setFallbackUri("forward:/fallback").setName("fallbackcmd"))
                                .addRequestHeader("X-Request-Foo", "SmallPlume"))
                        .uri("lb://smt-service-sys"))
                .route(r -> r.weight("provide-bus", 80).and().path("/smt-service-bus/**")
                        .filters(f -> f.retry(3).addRequestHeader("X-Request-Foo", "SmallPlume").stripPrefix(1))
                        .uri("lb://smt-service-business"))
                .build();
    }

    /**
     * Rate-limiting key resolver that keys requests by the client's IP address.
     * If the remote address is missing or unresolved, it emits no key instead of throwing a
     * NullPointerException. The request rate limiter's empty-key handling then decides the outcome.
     *
     * @return resolver emitting the remote host address, or empty when it is unknown
     */
    @Bean(name = "remoteAddrKeyResolver")
    public KeyResolver remoteAddrKeyResolver() {
        return exchange -> Mono.justOrEmpty(exchange.getRequest().getRemoteAddress())
                // getAddress() is null for an unresolved InetSocketAddress
                .flatMap(remote -> Mono.justOrEmpty(remote.getAddress()))
                .map(address -> address.getHostAddress());
    }

    /**
     * Custom pre-filter factory, registered under the bean name "pre".
     *
     * @return a new {@link PreGatewayFilterFactory}
     */
    @Bean(name = "pre")
    public PreGatewayFilterFactory preGatewayFilterFactory() {
        return new PreGatewayFilterFactory();
    }

    /**
     * Custom post-filter factory, registered under the bean name "post".
     *
     * @return a new {@link PostGatewayFilterFactory}
     */
    @Bean(name = "post")
    public PostGatewayFilterFactory postGatewayFilterFactory() {
        return new PostGatewayFilterFactory();
    }

}
